package train

import (
	"fmt"
	"iter"

	"github.com/gomlx/gomlx/internal/exceptions"
	"github.com/gomlx/gomlx/pkg/core/graph"
	"github.com/gomlx/gomlx/pkg/core/tensors"
	"github.com/gomlx/gomlx/pkg/ml/context"
	"github.com/gomlx/gomlx/pkg/ml/context/initializers"
	"github.com/gomlx/gopjrt/dtypes"
	"github.com/pkg/errors"
)

const (
	// AccumulatedGradientsScope is the name of the top-level scope under which the gradient accumulator variables
	// are created: each accumulator reuses the scope and name of its trainable variable, prefixed by this scope.
	AccumulatedGradientsScope = "acc_grads"
)

// OptimizeWithGradients must be implemented by an optimizer for it to be usable with Trainer.AccumulateGradients.
type OptimizeWithGradients interface {
	// UpdateGraphWithGradients is the counterpart of optimizers.Interface.UpdateGraph that receives the
	// already computed gradients (e.g.: the mean of the accumulated gradients) instead of the loss.
	// lossDType is the dtype of the loss the gradients were computed from.
	UpdateGraphWithGradients(
		ctx *context.Context,
		gradients []*graph.Node,
		lossDType dtypes.DType,
	)
}

// AccumulateGradients configures the trainer to accumulate the gradients of numAccumulatingSteps train steps before
// actually applying them (their mean) with the optimizer.
// globalStep will only be updated after numAccumulatingSteps are fed with Trainer.TrainStep.
//
// Notice that setting this makes the concept of "TrainStep" and "GlobalStep" diverge: there will now be
// numAccumulatingSteps "train steps" per "global step".
//
// It returns an error if the trainer has no optimizer, if the optimizer does not implement OptimizeWithGradients,
// or if numAccumulatingSteps < 1 (the accumulated gradients are divided by numAccumulatingSteps, so 0 would
// produce infinities and a negative value would invert the gradients).
func (r *Trainer) AccumulateGradients(numAccumulatingSteps int) error {
	if r.optimizer == nil {
		return errors.New("optimizer is nil!?")
	}
	if _, ok := r.optimizer.(OptimizeWithGradients); !ok {
		return errors.Errorf("optimizer %T does not implement OptimizeWithGradients -- to use AccumulateGradients use an optimizer that supports it (e.g.: SGD, Adam)", r.optimizer)
	}
	if numAccumulatingSteps < 1 {
		return errors.Errorf("AccumulateGradients requires numAccumulatingSteps >= 1, got %d", numAccumulatingSteps)
	}
	r.accumulateGradients = true
	r.accumulateGradientsSteps = numAccumulatingSteps
	r.accumulateGradientsCurrentStep = 0
	return nil
}

// NumAccumulatingSteps returns the number of train steps whose gradients are accumulated before being applied,
// as configured with AccumulateGradients.
// If gradient accumulation is not enabled, it returns 0.
func (r *Trainer) NumAccumulatingSteps() int {
	if !r.accumulateGradients {
		// Make the documented contract explicit instead of relying on the field's zero value.
		return 0
	}
	return r.accumulateGradientsSteps
}

// iterTrainableAndAccumulatorVariables iterates over the trainable variables used by the graph g and, for each one,
// yields the trainable variable together with its gradient accumulator variable.
//
// The accumulator lives under the absolute scope "<ScopeSeparator><AccumulatedGradientsScope><original scope>",
// has the same name and shape as the trainable variable, is initialized with zeros and is not trainable.
// It is created if it doesn't exist yet, and reused otherwise.
//
// The trainable variables are collected before any accumulator is created, so the context's variable collection
// is never modified while it is being iterated.
func iterTrainableAndAccumulatorVariables(ctx *context.Context, g *graph.Graph) iter.Seq2[*context.Variable, *context.Variable] {
	return func(yield func(trainable, accumulator *context.Variable) bool) {
		// Snapshot first: creating the accumulators below adds new variables to the context.
		var trainables []*context.Variable
		for v := range ctx.IterVariables() {
			if v.Trainable && v.InUseByGraph(g) {
				// We are only interested in trainable variables used by this graph.
				trainables = append(trainables, v)
			}
		}

		// It shouldn't matter if it's the first time or not creating the accumulator variables.
		accCtx := ctx.Checked(false)
		for _, v := range trainables {
			scopePath := fmt.Sprintf("%s%s%s", context.ScopeSeparator, AccumulatedGradientsScope, v.Scope())
			accumulator := accCtx.InAbsPath(scopePath).
				WithInitializer(initializers.Zero).
				VariableWithShape(v.Name(), v.Shape().Clone()).
				SetTrainable(false)
			if !yield(v, accumulator) {
				return
			}
		}
	}
}

// accumulateStepGraphImpl builds the computation graph for one gradient-accumulation train step.
// It is called by the context executors (`r.accumulateGradientsExecMap` and
// `r.accumulateGradientsAndApplyExecMap`) every time a graph needs to be built: the first time, when the batch
// size changes, or when the dataset spec changes.
//
// Every step runs the model, collects the loss (from r.lossFn and/or AddLoss), updates the train metrics, and
// adds the gradient of the loss w.r.t. each trainable variable to the variable's accumulator
// (see iterTrainableAndAccumulatorVariables).
//
// If applyGradients is true, it also divides the accumulated gradients by r.accumulateGradientsSteps, applies
// that mean with the optimizer (which must implement OptimizeWithGradients), executes the registered
// ContextGraphFn hooks (ExecPerStepUpdateGraphFn) and resets the accumulators to zero.
//
// It returns the metrics nodes built by r.metricsUpdatesGraph. It panics (exceptions.Panicf) if there is no loss
// to optimize, or if the gradients don't line up with the trainable variables.
func (r *Trainer) accumulateStepGraphImpl(spec any, ctx *context.Context, inputs, labels []*graph.Node, applyGradients bool) (metrics []*graph.Node) {
	// NOTE(review): assumes inputs is non-empty; the graph is taken from its first element.
	g := inputs[0].Graph()
	ctx.SetTraining(g, true) // Some layers behave differently if in training.
	if applyGradients {
		// Generally, variables should already have been created at this point.
		ctx = ctx.Checked(false)
	}

	// AddLoss generated by the given lossFn.
	predictions := r.modelFn(ctx, spec, inputs)
	if r.lossFn != nil {
		baseLoss := r.lossFnScalarLoss(ctx, labels, predictions)
		SetLossNoRegularization(ctx, baseLoss)
		AddLoss(ctx, baseLoss)
	}

	// Total loss, as collected by GetLosses (it includes anything registered with AddLoss).
	loss := GetLosses(ctx, g)
	if loss == nil {
		exceptions.Panicf("no loss function defined (or it returned nil), and no loss set with AddLoss(), " +
			"there is nothing to optimize!?")
	}

	// Metrics updates. They include: batch loss, exponential moving average of the batch loss.
	if len(predictions) == 0 {
		// We create a zero prediction (same dtype as loss), because the metrics require something.
		predictions = []*graph.Node{graph.ScalarZero(g, loss.DType())}
	}
	metrics = r.metricsUpdatesGraph(ctx, labels, predictions, r.trainMetrics)

	// Calculate and accumulate gradients: grads is a slice of gradients for each trainable variable, in
	// the same order as the variables are iterated.
	grads := ctx.BuildTrainableVariablesGradientsGraph(loss)
	numGrads := len(grads)
	accVars := make([]*context.Variable, 0, numGrads)
	for v, accVar := range iterTrainableAndAccumulatorVariables(ctx, g) {
		varIdx := len(accVars)
		if varIdx >= numGrads {
			exceptions.Panicf("more trainable variables than gradients (%d), this should not happen!", numGrads)
		}
		grad := grads[varIdx]
		if !v.Shape().Equal(grad.Shape()) {
			exceptions.Panicf(
				"shape mismatch between variable %q (shape=%s) and the corresponding gradient #%d (shape=%s)",
				v.ScopeAndName(), v.Shape(), varIdx, grad.Shape())
		}
		// Store the updated accumulator, and replace the gradient by the accumulated value.
		grads[varIdx] = graph.Add(accVar.ValueGraph(g), grad)
		accVar.SetValueGraph(grads[varIdx])
		accVars = append(accVars, accVar)
	}
	if len(accVars) != numGrads {
		exceptions.Panicf("more gradients (%d) than trainable variables (%d), this should not happen!",
			numGrads, len(accVars))
	}
	if !applyGradients {
		// Not yet time to apply the gradients, return.
		return metrics
	}

	// Take the mean of the accumulated gradients.
	ratio := 1.0 / float64(r.accumulateGradientsSteps)
	for ii := range grads {
		grads[ii] = graph.MulScalar(grads[ii], ratio)
	}

	// Apply mean of accumulated gradients with optimizer:
	opt := r.optimizer.(OptimizeWithGradients)
	opt.UpdateGraphWithGradients(ctx, grads, loss.DType())

	// Execute registered ContextGraphFn hooks for the current graph.
	ExecPerStepUpdateGraphFn(ctx, g)

	// Reset exactly the accumulators that were updated above.
	for _, accVar := range accVars {
		accVar.SetValueGraph(graph.Zeros(g, accVar.Shape()))
	}
	return metrics
}

// accumulateStepNoApplyGraph is the graph function for the intermediate steps of an accumulation cycle:
// gradients are only added to the accumulators, not applied.
func (r *Trainer) accumulateStepNoApplyGraph(spec any, ctx *context.Context, inputs, labels []*graph.Node) (metrics []*graph.Node) {
	const applyGradients = false
	metrics = r.accumulateStepGraphImpl(spec, ctx, inputs, labels, applyGradients)
	return metrics
}

// accumulateStepAndApplyGraph is the graph function for the last step of an accumulation cycle:
// the gradients are accumulated and then applied with the optimizer.
func (r *Trainer) accumulateStepAndApplyGraph(spec any, ctx *context.Context, inputs, labels []*graph.Node) (metrics []*graph.Node) {
	const applyGradients = true
	metrics = r.accumulateStepGraphImpl(spec, ctx, inputs, labels, applyGradients)
	return metrics
}

// trainStepWithAccumulateGradients executes one train step when gradient accumulation is enabled.
// It advances the step counter of the current accumulation cycle. Every step but the last of the cycle only
// accumulates gradients (using r.accumulateGradientsExecMap). The last step, once the counter reaches
// r.accumulateGradientsSteps, also applies them (using r.accumulateGradientsAndApplyExecMap) and restarts the counter.
func (r *Trainer) trainStepWithAccumulateGradients(spec any, inputs, labels []*tensors.Tensor) (metrics []*tensors.Tensor, err error) {
	r.accumulateGradientsCurrentStep++
	if r.accumulateGradientsCurrentStep >= r.accumulateGradientsSteps {
		// Last step of the cycle: restart the counter and apply the accumulated gradients.
		r.accumulateGradientsCurrentStep = 0
		return r.callGraphFn(r.accumulateStepAndApplyGraph, TrainType, r.accumulateGradientsAndApplyExecMap, spec, inputs, labels)
	}
	// Intermediate step: only accumulate.
	return r.callGraphFn(r.accumulateStepNoApplyGraph, TrainType, r.accumulateGradientsExecMap, spec, inputs, labels)
}
